#include "global.h"
#include "program.h"
#include "benchmark.h"
#include "wam.h"
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <random>

namespace lp {

	/// Prints one statistics report per distinct key of `samplev` to std::cout.
	///
	/// The table is walked group by group (all entries sharing a key, located
	/// with `upper_bound`). For each group a header line naming the key is
	/// written and the group's `stat` values are handed to the iterator-range
	/// overload of `print_statistics`.
	///
	/// @param samplev  table of collected statistics; keys are printed as
	///                 `name/number` (presumably predicate name and arity).
	void print_statistics(const stats_tbl& samplev)
	{
		for (auto i = samplev.begin(); i != samplev.end(); ) {
			// One past the last entry sharing i's key; looked up once per group.
			const auto group_end = samplev.upper_bound(i->first);
			std::cout << "Statistics for predicate " << fmap.get_data(i->first.first) << "/" << i->first.second << "\n";
			// Iterate directly instead of through a lambda that names the pair type,
			// so no converting copy of each entry is made.
			std::vector<stat> vec;
			for (auto j = i; j != group_end; ++j) {
				vec.push_back(j->second);
			}
			print_statistics(vec.begin(),vec.end());
			i = group_end;
		}
	}


	/// Repeats `generalize(method)` on fresh copies of `kb` and reports the
	/// accumulated statistics.
	///
	/// @param count   number of repetitions; nothing is run when it is <= 0.
	/// @param method  name of the generalization method passed to `generalize`.
	/// @param kb      knowledge base copied anew before every repetition.
	/// @return the program copy the last repetition worked on (a
	///         default-constructed Program if no repetition ran).
	Program benchmark(int count, const char* method, const lp::Program& kb)
	{
		stats_tbl collected;
		Program last_run;

		for (int run = 0; run < count; ++run) {
			// Start every repetition from the untouched knowledge base.
			last_run = kb;
			auto run_stats = last_run.generalize(method);
			collected.insert(run_stats.begin(), run_stats.end());
		}

		print_statistics(collected);
		return last_run;
	}


	/// Exception tag thrown by file2clauses when the requested path does not
	/// exist or is not a regular file.
	struct no_file {};

	/// Flat list of (signature, clause) pairs.
	using fold_t = std::vector<std::pair<sign_t,clause>>;
	/// Reads the program stored in `datafile` and flattens the table exposed by
	/// `con_begin()`/`con_end()` into a list of (signature, clause) pairs.
	///
	/// @param datafile  path of the file to read.
	/// @return one pair per clause, in table order.
	/// @throws no_file            if the path is missing or not a regular file.
	/// @throws destruct_and_exit  if the file exists but cannot be opened.
	inline fold_t file2clauses(const boost::filesystem::path& datafile)
	{
		namespace bfs = boost::filesystem;

		const bool is_usable = bfs::exists(datafile) && bfs::is_regular_file(datafile);
		if (!is_usable) {
			throw no_file();
		}

		ifstream input(datafile.string().c_str());
		if (!input) {
			std::cerr << "Error: could not open file " << datafile << "\n";
			throw destruct_and_exit();
		}

		Program loaded;
		loaded.read(input);

		fold_t pairs;
		for (auto entry = loaded.con_begin(); entry != loaded.con_end(); ++entry) {
			for (const auto& example : entry->second) {
				pairs.push_back(std::make_pair(entry->first, example));
			}
		}
		std::cout << "read file " << datafile << "\n";
		return pairs;
	}


	/// Cross-validates `method` over the numbered data files of one folder.
	///
	/// `base_path` must contain `1.pl`; further files `2.pl`, `3.pl`, ... are
	/// picked up for as long as they can be opened. An optional `shared.pl` is
	/// consulted into the knowledge base first. For each file `vk.pl` in turn:
	/// all other files are consulted as training data, `generalize(method)` is
	/// run, and every clause of `vk.pl` is turned into a coverage query whose
	/// outcome updates the true/false positive/negative counters of the
	/// matching signature.
	///
	/// @param method     generalization method name, forwarded to `generalize`.
	/// @param kb         knowledge base; taken by value and extended locally.
	/// @param base_path  folder holding the numbered data files.
	/// @return table with the statistics of every fold.
	/// @throws destruct_and_exit  if `1.pl` is missing or a data file cannot be
	///                            consulted.
	/// @note Terminates the process (assert / exit(1)) when a validation clause
	///       has a signature for which `generalize` returned no statistics.
	stats_tbl validate(const std::string& method, Program kb, const boost::filesystem::path& base_path)
	{
		// The folder name cannot serve as target concept: it will be a number.
		const auto first_file = base_path / "1.pl";
		if (!boost::filesystem::exists(first_file) || !boost::filesystem::is_regular_file(first_file)) {
			std::cerr << "Error: no file 1.pl in folder " << base_path << "\n";
			throw destruct_and_exit();
		}

		stats_tbl stbl;

		// kb is already a private copy (by-value parameter), so the optional
		// shared definitions are consulted straight into it.
		if (consult_file(base_path / "shared.pl", kb, true)) {
			DEBUG_INFO(std::cerr << "Read shared.pl\n");
		}

		// How many files 1.pl, 2.pl, ...? Stop at the first one that cannot be opened.
		int datasets = 0;
		for (int k = 1; ; ++k) {
			const auto datafile = base_path / (std::to_string(k) + ".pl");
			ifstream inf( datafile.string().c_str() );
			if (!inf) break;
			datasets = k;
		}

		DEBUG_INFO(std::cerr << "We have " << datasets << " datasets\n");

		// Cross validate: file vk is held out, all others are training data
		for (int vk = 1; vk <= datasets; ++vk) {
			Program tmp = kb;
			// Add examples
			for (int tk = 1; tk <= datasets; ++tk) {
				if (tk == vk) continue;
				const auto datafile = base_path / (std::to_string(tk) + ".pl");
				if (!consult_file(datafile, tmp)) {
					std::cerr << "Error: could not read file " << datafile << "\n";
					throw destruct_and_exit();
				}
			}

			// Generalize (generalize() already prints the program in debug mode)
			std::cerr << "  generalizing " << vk << "\n";
			auto stats = tmp.generalize(method.c_str());

			// Measure Accuracy on the held-out file
			const auto datafile = base_path / (std::to_string(vk) + ".pl");
			Program valid;
			if (!consult_file(datafile, valid)) {
				std::cerr << "Error: could not read test file " << datafile << "\n";
				throw destruct_and_exit();
			}
			// Plain iterator loop: no per-entry copy of the clause vectors.
			for (auto entry = valid.con_begin(); entry != valid.con_end(); ++entry) {
				// Entries stored under if_id/1 are treated as negative examples.
				const bool is_constraint = (entry->first == sign_t(if_id,1));
				for (const auto& cl : entry->second) {
					// Build the query: constraints are used as they are, anything
					// else is wrapped as the single argument of an if_id functor.
					Functor q = (is_constraint ? Functor(decompile(cl)) : Functor(if_id,new Functor(decompile(entry->first,cl))) );
					assert(q.sign(if_id,1));
					clause qcl = compile(q);
					const sign_t sign = q.arg_first()->signature();
					auto at = std::find_if(stats.begin(),stats.end(),[&](const std::pair<sign_t,lp::stat>& p2){ return p2.first == sign; });
					if (at == stats.end()) {
						std::cerr << "ERROR: looking for signature " << Functor::get_data(sign.first) << "/" << sign.second << " among:\n";
						for (const auto& p2 : stats) {
							std::cerr << Functor::get_data(p2.first.first) << "/" << p2.first.second << "\n";
						}
						std::cerr << "You probably specified the wrong target predicate in a modeh declaration\n";
						std::cerr << "Was this a constraint? " << is_constraint << "\n";
						std::cerr << "q: " << q << "\n";
						assert(false);
						exit(1);
					}
					// Increase counters (update accuracy)
					const bool is_cov = tmp.covers(qcl);
					if (!is_constraint) {
						// Positive example
						if (is_cov) {
							++(at->second[stat::true_positive]);
						} else {
							++(at->second[stat::false_negative]);
						}
					} else {
						// Negative example
						if (is_cov) {
							++(at->second[stat::false_positive]);
						} else {
							++(at->second[stat::true_negative]);
						}
					}
				}
			}

			// Add stats
			stbl.insert(stats.begin(),stats.end());
			DEBUG_INFO(std::cerr << "generalizing done " << stbl.size() << "\n");
		}

		return stbl;
	}

	/// Validates the data below `base_path` and prints the collected statistics.
	///
	/// Global definitions are consulted into `kb` from `global.pl` and/or from
	/// `<folder name>.pl`; at least one of the two must load. If a sub-folder
	/// `1` exists, the folders `1`, `2`, ... are each passed to `validate` until
	/// the first missing number; otherwise `base_path` itself is validated.
	/// Despite the name, only this single level of sub-folders is visited.
	///
	/// @param method     generalization method name, forwarded to `validate`.
	/// @param kb         knowledge base; taken by value and extended locally.
	/// @param base_path  folder to validate.
	/// @return table with the statistics of all validated folders.
	/// @throws destruct_and_exit  if `base_path` is not a directory, if no global
	///                            file could be loaded, or if both `1` and `1.pl`
	///                            exist.
	stats_tbl validate_rec(const std::string& method, Program kb, const boost::filesystem::path& base_path)
	{
		stats_tbl stbl;

		if (!boost::filesystem::is_directory(base_path)) {
			std::cerr << "Error: path is not directory: " << base_path << "\n";
			throw destruct_and_exit();
		}

		// Is there a global file?
		bool has_global = false;
		if (consult_file(base_path / "global.pl", kb, true)) {
			DEBUG_BASIC( std::cout << "Loaded global.pl\n" );
			has_global = true;
		}
		// A file named after the folder counts as global file too. The size check
		// skips the case where the folder name is empty and only ".pl" remains.
		const std::string file = base_path.filename().string() + ".pl";
		if (file.size() > 3 && consult_file(base_path / file, kb, true)) {
			DEBUG_BASIC( std::cout << "Loaded " << file << "\n" );
			has_global = true;
		}
		if (!has_global) {
			std::cerr << "Error: no file global.pl or " << file << " (if no global definitions are needed, leave an empty file 'global.pl')\n";
			throw destruct_and_exit();
		}

		// Validate here if we have a "1.pl", otherwise look for folder "1"
		if (boost::filesystem::exists(base_path / "1") && boost::filesystem::is_directory(base_path / "1")) {
			if (boost::filesystem::exists(base_path / "1.pl")) {
				std::cerr << "Error: both directory 1 and file 1.pl exist\n";
				throw destruct_and_exit();
			}
			for (int k = 1; ; ++k) {
				const auto path = base_path / std::to_string(k);
				if (boost::filesystem::exists(path) && boost::filesystem::is_directory(path)) {
					std::cerr << "Processing directory " << k << "\n";
					auto res = validate(method,kb,path);
					stbl.insert(res.begin(),res.end());
				} else {
					std::cerr << "Processed " << (k-1) << " directories\n";
					break;
				}
			}
		} else {
			// Try to validate this folder
			auto res = validate(method,kb,base_path);
			stbl.insert(res.begin(),res.end());
		}

		// Report: one section per distinct key, fed with that key's stat values
		decltype(stbl.begin()) iend;
		for (auto i = stbl.begin(); i != stbl.end(); i = iend) {
			iend = stbl.upper_bound(i->first);
			std::cout << "---------------------------------------------------------\n";
			std::cout << "Statistics for " << Functor::get_data(i->first.first) << "/" << i->first.second << " [" << method << "]\n";
			auto get_second = [&](const std::pair<const sign_t,lp::stat>& p){ return p.second; };
			print_statistics(
				boost::make_transform_iterator(i,get_second),
				boost::make_transform_iterator(iend,get_second));
			std::cout << "---------------------------------------------------------\n";
		}

		return stbl;
	}


	// Validate without using cross validation files
	/// Runs automatic cross validation `count` times for the methods selected by
	/// `method_mask` and prints the statistics gathered per method.
	///
	/// The mask is read positionally as [NrSample][Emulated][Heuristic][Enumerate];
	/// a position set to '1' enables that method. Position 0 may also be '3',
	/// which adds two "nrsample" runs with sequential positive ordering, one
	/// with `detect_dependent` off and one with it on. Positions missing from a
	/// short mask count as disabled.
	///
	/// @param count        number of cross-validation iterations.
	/// @param kfold        fold count, forwarded to the per-iteration overload.
	/// @param method_mask  digit string selecting the methods.
	/// @param kb           knowledge base; its `params` seed every method's settings.
	/// @throws destruct_and_exit  if the mask contains a non-digit character.
	void auto_cross_validation(
		int count, 
		int kfold,
		const std::string& method_mask,
		const lp::Program& kb)
	{
		std::vector<validation_algorithm> methods;

		// std::isdigit is undefined for negative char values, hence the detour
		// through unsigned char.
		const bool only_digits = std::all_of(method_mask.begin(),method_mask.end(),
			[](char ch){ return std::isdigit(static_cast<unsigned char>(ch)) != 0; });
		if (!only_digits) {
			std::cerr << "ERROR: methods during auto cross validation are specified as booleans: [NrSample][Emulated][Heuristic][Enumerate]\n";
			throw destruct_and_exit();
		}

		try {
			if (method_mask.at(0) == '1') {
				methods.push_back(validation_algorithm("nrsample",kb.params));
			} else if (method_mask.at(0) == '3') {
				parameters p = kb.params;
				p[parameters::detect_dependent] = 0LL;
				p[parameters::pos_order] = "sequential";
				methods.push_back(validation_algorithm("nrsample",p));

				p = kb.params;
				p[parameters::detect_dependent] = 1LL;
				p[parameters::pos_order] = "sequential";
				methods.push_back(validation_algorithm("nrsample",p));
			}
			if (method_mask.at(1) == '1')
				methods.push_back(validation_algorithm("emulate_nrsample",kb.params));
			if (method_mask.at(2) == '1')
				methods.push_back(validation_algorithm("heuristic",kb.params));
			if (method_mask.at(3) == '1')
				methods.push_back(validation_algorithm("enumerate",kb.params));
		} catch (const std::out_of_range&) {
			// Deliberately ignored: at() throws once the mask runs out of
			// characters, which simply leaves the remaining methods disabled.
			// NOTE(review): an out_of_range thrown by anything else in the try
			// block is swallowed here as well.
		}
		for (int i = 0; i < count; ++i) {
			std::cerr << "Cross Validating, iteration " << (i+1) << "...\n";
			auto_cross_validation(kfold,methods,kb);
		}
		for (auto& p : methods) {
			std::cerr << "\nStatistics for " << p.method << "\n";
			print_statistics(p.result);
		}
	}


	inline void shuffle(std::list<clause>& li)
	{
		std::vector<clause> tmp;
		tmp.reserve(li.size());
		for (auto&& cl : li) {
			tmp.push_back(std::move(cl));
		}
		std::random_shuffle(tmp.begin(),tmp.end());
		li.clear();
		for (auto&& cl : tmp) li.push_back(std::move(cl));
	}

	typedef std::vector<std::pair<std::list<clause>,std::list<clause>>> partitions;
	partitions make_folds(
		std::list<clause> pex,
		std::list<clause> nex,
		int kfold)
	{
		// Implement kfold cross validation by extracting examples from kb_orig
		partitions valid;
		if (kfold <= 0 || kfold > int(pex.size())) {
			std::cerr << "Error: Cannot crossvalidate with " << kfold << "-fold and " << pex.size() << " examples\n";
			throw destruct_and_exit();
		}

		//std::cerr << "Kfold: " << kfold << "\n";
		// Shuffle examples
		shuffle(pex);
		shuffle(nex);

		//std::cerr << "Pex size: " << pex.size() << "\n";
		// Each fold has round_down(pex.size() / kfold) elements
		const int pos_fold_size = int(pex.size()) / kfold;
		const int neg_fold_size = int(nex.size()) / kfold;
		//std::cerr << "Fold size: " << pos_fold_size << "/" << neg_fold_size << "\n";
		// Split 
		valid.reserve(kfold);
		for (int k = 0; k < kfold; ++k) {
			std::list<clause> pos,neg;

			auto j = pex.begin();
			std::advance(j,pos_fold_size);
			pos.splice(pos.begin(),pex,pex.begin(),j);

			auto jn = nex.begin();
			std::advance(jn,neg_fold_size);
			neg.splice(neg.begin(),nex,nex.begin(),jn);
			//std::cerr << "Fold " << (k+1) << " contains " << pos.size() << "+" << neg.size() << " elements\n";

			valid.push_back(std::make_pair(std::move(pos),std::move(neg)));
		}
		assert(valid.size() == kfold && 
			std::all_of(valid.begin(),valid.end(),[&](const std::pair<std::list<clause>,std::list<clause>>& l){ 
				return l.first.size() == pos_fold_size && l.second.size() == neg_fold_size; }));
		return valid;
	}



	void auto_cross_validation(
		int kfold, 
		std::vector<validation_algorithm>& methods,
		lp::Program kb)
	{
		using namespace lp;
		stats_tbl samplev;

		// Get signatures from all modeh's
		std::set<sign_t> signs;
		std::for_each(kb.mode_begin(),kb.mode_end(),[&](const Mode& m){ if (m.is_head()) signs.insert(m.atom().signature()); });
		if (signs.size() != 1) {
			std::cerr << "ERROR: only one modeh declaration allowed when using automatic cross validation\n";
			assert(false);
			exit(1);
		}
		const sign_t& sign = *signs.begin();

		// Strip kb from examples
		// Make table of randomly shuffled positive examples for each signature
		partitions parts;
		std::list<clause> posv,negv;
		//std::cerr << "Sign: " << Functor::get_data(sign.first) << "/" << sign.second << "\n";
		kb.examples_to_queries(sign,std::back_inserter(posv),std::back_inserter(negv),true);
		//for (auto& cl : posv) {
		//	std::cerr << "posv: " << decompile(cl) << "\n";
		//}
		// If kfold is set to 1, use leave-one-out
		if (kfold == 1) {
			kfold = std::min(posv.size(),negv.size());
			DEBUG_WARNING(std::cerr << "Setting kfold to leave-one-out: " << kfold << " = min(" << posv.size() << "," << negv.size() << ")\n");
		}
		parts = make_folds(posv,negv,kfold);
		//postbl.insert(std::make_pair(sign,std::move(folds)));
		//std::cerr << "KB without examples: " << kb << "\n";
		//std::cerr << "postbl size: " << postbl.size() << "\n";

		// Cross validate
		assert(!postbl.empty());
		std::vector<double> accuracy;

		// valid[k] is validation data, the rest are used as examples
		// Make Knowledge Base using parts[k] for verification
		//std::cerr << "partitions: " << parts.size() << "\n";
		for (int k = 0; k < int(parts.size()); ++k) {
			lp::Program tmp = kb;
			for (int i = 0; i < int(parts.size()); ++i) {
				if (i == k) continue;
				//std::cerr << "parts[" << i << "].first.size() = " << parts[i].first.size() << "\n";
				std::for_each(parts[i].first.begin(),parts[i].first.end(),[&](const clause& cl){ 
					Functor pe = *decompile(cl).arg_first();
					//std::cerr << "Adding example: " << pe << "\n";
					tmp.push_back( std::move(pe) );
				});
				std::for_each(parts[i].second.begin(),parts[i].second.end(),[&](const clause& cl){ tmp.push_back(cl); });
			}
			// Knowledge base tmp built
			//std::cerr << "Created test: " << tmp << "\n";
			// For each method, cross validate
			for (auto& p : methods) {
				//std::cerr << "  cross Validating, method: " << p.method << " ...\n";
				Program prog = tmp;
				prog.params = p.settings;
				auto stats = prog.generalize(p.method.c_str());

				// Increase counters (accuracy)
				// Find stats
				auto at = std::find_if(stats.begin(),stats.end(),[&](const std::pair<sign_t,lp::stat>& p){ return p.first == sign; });
				if (at == stats.end()) {
					std::cerr << "ERROR: stats could not be found\n";
					assert(false);
					exit(1);
				}
				auto& sts = at->second;
				// Test positive coverage
				int tp = 0,tn = 0,fp = 0,fn = 0;

				for (auto& cl : parts[k].first) {
					const bool is_cov = prog.covers(cl);
					// Positive example
					if (is_cov) {
						++tp;
						++sts[stat::true_positive];
					} else {
						//std::cerr << "Error, covered false negative: " << q << "\n";
						++fn;
						++sts[stat::false_negative];
					}
				}
				// Test negative coverage
				for (auto& cl : parts[k].second) {
					const bool is_cov = prog.covers(cl);
					if (is_cov) {
						++fp;
						++sts[stat::false_positive];
					} else {
						//std::cerr << "Error, covered false negative: " << q << "\n";
						++tn;
						++sts[stat::true_negative];
					}
				}

				DEBUG_BASIC(std::cout << "Accuracy: " << double(tp+tn) / double(tp+tn+fp+fn) << " = (" << tp << "+" << tn << ")/(" << tp << "+" << tn << "+" << fp << "+" << fn << ")\n");

				p.result.insert(std::make_pair(sign,sts));
			} // for each method
		}

	}




} // namespace lp


